
#include "program.h"
#include "propositional.h"
#include "fitness.h"
#include "bottom.h"

#include <algorithm> // std::transform, std::all_of
#include <cctype>    // std::tolower

namespace lp {

	/// Index the body literals of a bottom clause by signature.
	///
	/// Each entry maps a literal's signature to a pair of
	/// (iterator to that body literal, its 1-based position in the body).
	/// @param bottom bottom clause whose body is indexed; the stored iterators
	///               point into it, so it must outlive the returned map
	/// @return the signature -> (literal, position) map
	literal_map make_literal_map(const Functor& bottom)
	{
		literal_map result;
		int position = 0;
		for (auto lit = bottom.body_begin(); lit != bottom.body_end(); ++lit) {
			++position; // body positions are numbered starting at 1
			result.insert(std::make_pair(lit->signature(), std::make_pair(lit, position)));
		}
		return result;
	}

	void build_constraint_literals(
		std::vector<Functor>::const_iterator lit,
		std::vector<Functor>::const_iterator end,
		int q_at,
		const literal_map& bmap,
		ClauseDB& db,
		Subst& subs,
		std::map<id_type,id_type>& var2fun,
		std::vector<int>& sass)
	{
		auto get_fun = [&](id_type id){
			auto at = var2fun.find(id);
			if (at == var2fun.end()) at = var2fun.insert(std::make_pair(id,fmap.new_fun_id())).first;
			return at->second;
		};

		if (lit == end) {
			// Success
			PClause pcl;
			for (int i : sass) pcl.insert(i);
			db.insert(std::move(pcl));
		} else {
			const Functor& f = *lit;
			// Search range rng, but only up to i->second (left of q)
			auto rng = bmap.equal_range(f.signature());
			bool found = false;
			for (auto j = rng.first; j != rng.second; ++j) {
				if (j->second.second >= q_at) {
					assert( all_of(j,rng.second,[&](const decltype(*j)& p){ return p.second.second >= q_at; }) );
					break;
				}
				std::set<id_type> vars;
				// Prepare subs with X -> _x bindings to ground bottom clause
				for (const Functor& n : *j->second.first) { 
					if (n.is_variable() && !subs.get(n.id())) {
						subs.steal(n.id(),new Functor(get_fun(n.id())));
						vars.insert(n.id());
					}
				}
				// Unify
				if (unify(&f,&*j->second.first,subs,vars)) {
					//std::cerr << "  PMatched " << f << " with " << *j->second.first << " : " << subs << "\n";
					sass.push_back(-j->second.second);
					build_constraint_literals(std::next(lit),end,q_at,bmap,db,subs,var2fun,sass);
					sass.pop_back();
				}
				// Undo subs
				for (id_type i : vars) subs.erase(i);
			}
		}
	}

	void add_implication_constraints(
		const literal_map& bmap,
		const implication& ip,
		ClauseDB& db)
	{
		const Functor& q = ip.cons;
		const auto& ante = ip.ante;
		std::map<id_type,id_type> var2fun; // TODO: reuse this mapping
		auto get_fun = [&](id_type id){
			auto at = var2fun.find(id);
			if (at == var2fun.end()) at = var2fun.insert(std::make_pair(id,fmap.new_fun_id())).first;
			return at->second;
		};

		// Find all q
		auto q_rng = bmap.equal_range(q.signature());
		for (auto i = q_rng.first; i != q_rng.second; ++i) {
			Subst subs;
			std::set<id_type> vars;
			// Prepare subs with X -> _x bindings to ground bottom clause
			for (const Functor& n : *i->second.first) { 
				if (n.is_variable() && !subs.get(n.id())) {
					subs.steal(n.id(),new Functor(get_fun(n.id())));
					vars.insert(n.id());
				}
			}
			// Try to unify
			if (unify(&q,&*i->second.first,subs)) {
				//std::cerr << "QMatched " << q << " with " << *i->second.first << " : " << subs << "\n";
				std::vector<int> sass;
				sass.push_back(-i->second.second);
				build_constraint_literals(ante.begin(),ante.end(),i->second.second,bmap,db,subs,var2fun,sass);
			} // for each literal matching Q
			// note: no need to undo subs since it goes out of scope here
		} // for each q/n among bottom literals 
	}


	/// Generalize for predicate `sign` by handing two callbacks to Program::generalize.
	///
	///  - init_search resets the fitness function, the assignment prototype (vproto) and
	///    the clause database `allowed` built from the bottom clause. When the
	///    conditional_constraints parameter is set, it also adds the implication
	///    constraints of thy.imp.
	///  - find_candidate asks the configured SAT solver (DPLL or dual Horn) for the next
	///    assignment and evaluates the masked bottom clause. It then inserts the requested
	///    pruning clauses (prune_visit / prune_consistent / prune_inconsistent).
	///
	/// find_candidate ends the search by throwing:
	///  - search_completed, when the solver reports no assignment;
	///  - search_aborted, on max_assignments / time_out;
	///  - the last solution, when allowed.has_empty() becomes true.
	/// These are presumably handled inside generalize (TODO confirm).
	stat Program::generalize_nrsample(const sign_t& sign)
	{
		// Cache parameters
		const int param_csize = params.force_int(parameters::csize);
		(void)param_csize; // only read by assertions; silences unused warnings under NDEBUG
		const int param_noise = params.force_int(parameters::noise);
		const bool param_prune_visit = params.is_set(parameters::prune_visit);
		const bool param_prune_consistent = params.is_set(parameters::prune_consistent);
		const bool param_prune_inconsistent = params.is_set(parameters::prune_inconsistent);
		// Which SAT solver to use? (option is matched case-insensitively)
		enum class sat_solver_t { DPLL, DUAL_HORN };
		sat_solver_t sat_solver = sat_solver_t::DPLL;
		if (params[parameters::sat_solver].is_atom()) {
			std::string s = params[parameters::sat_solver].get_atom();
			// Lower-case in place: std::for_each(..., std::tolower) would discard tolower's
			// result. The unsigned char conversion avoids UB for negative char values.
			std::transform(s.begin(),s.end(),s.begin(),[](unsigned char c){ return static_cast<char>(std::tolower(c)); });
			if (s == "dpll") {
				sat_solver = sat_solver_t::DPLL;
				DEBUG_INFO(std::cout << "Using DPLL-based SAT Solver\n");
			} else if (s == "dual" || s == "dualhorn" || s == "dual_horn" || s == "dual horn") {
				sat_solver = sat_solver_t::DUAL_HORN;
				DEBUG_INFO(std::cout << "Using Dual Horn Clause SAT Solver\n");
			} else {
				DEBUG_WARNING(std::cerr << "Warning: unrecognized SAT solver option: " << s << ", defaulting to DPLL\n");
			}
		} else {
			DEBUG_WARNING(std::cerr << "Warning: SAT solver option is not string: " << params[parameters::sat_solver] << ", defaulting to DPLL\n");
		}

		// Create constraints based on modes
		Vass vproto;
		ClauseDB allowed;
		int call_count = 0;
		PClause last_cl; // last inserted clause
		// Get DPLL selection strategy
		unique_ptr<DPLL_Selection> select_ptr = get_dpll_selection(params);
		if (!select_ptr) {
			select_ptr.reset(new Simplicity_selection());
		}
		// Pick fitness function
		auto fitness = pick_fitness_function<bsf_type>(params);
		if (!fitness) {
			// Set default: lex
			fitness.reset(new fitness_lex<bsf_type>(params.force_int(parameters::terminate)));
		}

#ifndef NDEBUG
		// Every visited bitstring, used to assert that no point is visited twice
		vector<bitstring> history;
#endif

		auto init_search = [&]
		(Program& thy, Functor& bottom, int blen, const vector<Mode>& modes, 
			list<clause>& pex, list<Program::clause>& nex, Constraints& constr, stat& sts) -> bsf_type
		{
#ifndef NDEBUG
			history.clear();
#endif
			fitness->init(bottom,modes,thy,pex,nex,sts[stat::exc]);
			// Create prototype of model to store all settings
			vproto = Vass(bottom,modes,thy.params.force_int(parameters::csize),thy.params.force_int(parameters::dpllass));
			call_count = 0;
			last_cl.clear();
			allowed = ClauseDB(bottom,blen,modes,constr,thy.params);
			if (params.is_set(parameters::conditional_constraints)) {
				// The literal map depends only on bottom: build it once for all implications
				const literal_map bmap = make_literal_map(bottom);
				for (const auto& ip : thy.imp) {
					add_implication_constraints(bmap,ip,allowed);
				}
			}
			DEBUG_INFO(cerr << "Bottom Clause Constraints: " << allowed.size() << "\n");
			return bsf_type();
		};

		auto find_candidate = [&]
		(Program& thy, list<clause>& pex, list<Program::clause>& nex, Functor& bottom, int bsize, 
			const vector<Mode>& modes, const bsf_type& bsf, Constraints& constr, stat& sts,
			deadline_t deadline) -> bsf_type
		{
			// Get solution from the SAT solver
			Vass va = vproto;
			try {
				select_ptr->init(allowed,va,bsize);
				int res = allowed.simplify(va,bsize);
				if (res) {
					switch (sat_solver) {
					case sat_solver_t::DPLL:
						res = allowed.dpll(va,bsize,*select_ptr,deadline);
						break;
					case sat_solver_t::DUAL_HORN:
						res = allowed.dual_horn_solve(va,bsize);
						break;
					}
				}

				if (res <= 0) {
					DEBUG_INFO(cout << "Exhausted search space\n");
					throw search_completed();
				}
			} catch (const max_assignments&) {
				DEBUG_CRUCIAL(std::cerr << "Warning: maximal number of DPLL assignments reached\n");
				throw search_aborted();
			} catch (const time_out&) {
				DEBUG_CRUCIAL(std::cerr << "Warning: SAT solver timed out\n");
				throw search_aborted();
			}

			// DEBUGGING: check that we never revisit a point and respect csize
#ifndef NDEBUG
			const auto bs_tmp = make_bitstring(va);
			assert(find(history.begin(),history.end(),bs_tmp) == history.end());
			history.push_back(bs_tmp);
			assert( count(bs_tmp.begin(),bs_tmp.end(),1) <= param_csize );
#endif
			// Make candidate from bottom clause and mask
			bsf_type sol = make_bitstring(va);
			mask_bottom(bottom,sol.mask());
			assert( std::count_if(bottom.body_begin(),bottom.body_end(),[&](const Functor& t){ return !t.is_constant(true_id); }) 
				== std::count(sol.mask().begin(),sol.mask().end(),1) );

			++sts[stat::fitc]; // register fitness call
			fitness->evaluate(bottom,sol,thy,pex,nex,bsf,sts[stat::exc],deadline);
			// Update consistent/inconsistent stats
			if (sol.is_consistent(param_noise)) {
				++sts[stat::conc];
			} else {
				++sts[stat::inconc];
			}

			// Update Clause DB
			// Insert singleton constraint if requested
			if (param_prune_visit) {
				PClause cc = make_complement_clause(va,0);
				DEBUG_INFO( cout << "  Revisit Constraint: " << cc << "\n\n" );
				last_cl = cc;
				allowed.insert(move(cc));
			}
			if (sol.is_consistent(param_noise)) { 
				// Sol is Consistent
				if (param_prune_consistent) {
					PClause cc = make_complement_clause(va,1);
					DEBUG_INFO( cout << "  Prune down Constraint: " << cc << "\n\n" );
					last_cl = cc;
					allowed.insert(move(cc));
				}
			} else { 
				// Sol is Inconsistent
				if (param_prune_inconsistent) {
					PClause cc = make_complement_clause(va,-1);
					DEBUG_INFO( cout << "  Prune up Constraint: " << cc << "\n\n" );
					last_cl = cc;
					allowed.insert(move(cc));
				}
			}

			if (allowed.has_empty()) {
				DEBUG_INFO(cout << "Exhausted search space\n");
				throw sol; // last solution
			}

			return sol;
		};

		return this->generalize(sign,init_search,find_candidate);
	}


}


